import torch
import pandas as pd
import numpy as np
import h5py
from torch.utils.data import Dataset, DataLoader
import warnings
# NOTE(review): filterwarnings('ignore') with no category/module filter silences
# *every* warning for the whole process as a side effect of importing this
# module (including warnings raised by code that merely imports it). Consider
# scoping it with `warnings.catch_warnings()` around the specific noisy calls.
warnings.filterwarnings('ignore')

class TimeSeriesDataset(Dataset):
    """Sliding-window dataset over a numeric time series or a character corpus.

    Supported ``dataset_name`` values:

    * ``'electricity'`` -- CSV file. The timestamp column (``timestamp``,
      ``date`` or ``time``; otherwise the first column) is dropped, the last
      remaining column is the target and the other columns are inputs.
    * ``'metr-la'`` / ``'pems-bay'`` -- HDF5 file. A 2-D array is located in
      the file (see ``_read_h5_matrix``); its last column is the target and
      the other columns are inputs.
    * ``'harry_potter'`` -- GBK-encoded text file used for next-character
      prediction.

    Numeric data is z-scored column by column and cut into overlapping
    windows of ``seq_len`` steps (stride 1). Items are float32 tensor pairs
    ``(input, target)`` of shape ``(seq_len, num_features)`` and
    ``(seq_len, 1)``. For the text dataset both tensors are one-hot with shape
    ``(seq_len, vocab_size)`` and the target window is the input window
    shifted forward by one character.

    Args:
        data_path: path of the file to load.
        dataset_name: one of the names listed above.
        seq_len: window length in time steps / characters.
        max_days_per_user: if truthy, numeric datasets keep only the first
            ``max_days_per_user * 24`` rows.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """

    def __init__(self, data_path, dataset_name, seq_len=24, max_days_per_user=None):
        self.dataset_name = dataset_name
        self.seq_len = seq_len

        if dataset_name == 'electricity':
            self._load_electricity_data(data_path, max_days_per_user)
        elif dataset_name in ['metr-la', 'pems-bay']:
            self._load_h5_data(data_path, max_days_per_user)
        elif dataset_name == 'harry_potter':
            self._load_text_data(data_path)
        else:
            raise ValueError(f"Unsupported dataset: {dataset_name}")

    # ------------------------------------------------------------------ loaders

    def _load_electricity_data(self, data_path, max_days_per_user):
        """Load a CSV, drop its timestamp column and split inputs / last-column target."""
        df = pd.read_csv(data_path)

        # First recognised timestamp-like column name, else assume column 0.
        timestamp_col = next(
            (col for col in ('timestamp', 'date', 'time') if col in df.columns),
            df.columns[0],
        )

        if max_days_per_user:
            # NOTE(review): 24 rows per "day" assumes one row per hour — confirm.
            df = df.head(max_days_per_user * 24)

        data_cols = [col for col in df.columns if col != timestamp_col]
        all_data = df[data_cols].to_numpy(dtype=np.float32)

        self.input_data = all_data[:, :-1]
        self.target_data = all_data[:, -1]
        self._standardize_and_create_sequences()

    def _load_h5_data(self, data_path, max_days_per_user):
        """Load a 2-D array from an HDF5 file and split inputs / last-column target."""
        all_data = self._read_h5_matrix(data_path)
        print(f"Final data shape: {all_data.shape}")

        if max_days_per_user:
            # NOTE(review): 24 rows per "day" only holds for hourly data; if this
            # file uses a finer sampling interval the row cap is not in days.
            all_data = all_data[:max_days_per_user * 24]
            print(f"Limited to {max_days_per_user} days: {all_data.shape}")

        self.input_data = all_data[:, :-1].astype(np.float32)
        self.target_data = all_data[:, -1].astype(np.float32)
        self._standardize_and_create_sequences()

    @classmethod
    def _read_h5_matrix(cls, data_path):
        """Open ``data_path`` and return the array it most plausibly holds.

        A preferred key is chosen from the file name ('df/block0_values' for
        METR-LA, 'speed' for PEMS-BAY, otherwise 'data'). If reading that key
        (and its fallbacks) fails for any reason, the whole file is searched
        and the 2-D dataset with the most elements is used.

        Raises:
            ValueError: if no 2-D dataset exists anywhere in the file.
        """
        lowered = str(data_path).lower()  # str() so pathlib.Path also works
        if 'metr-la' in lowered:
            data_key, nested_key = 'df', 'block0_values'
        elif 'pems-bay' in lowered:
            data_key, nested_key = 'speed', None
        else:
            data_key, nested_key = 'data', None

        with h5py.File(data_path, 'r') as f:
            print(f"Available keys in H5 file: {list(f.keys())}")
            try:
                return cls._read_preferred_h5_array(f, data_key, nested_key)
            except Exception as e:  # deliberate best-effort: fall back to a full search
                print(f"Error loading data with key '{data_key}': {e}")
                print("Attempting to find any suitable 2D dataset...")
                datasets = dict(cls._walk_h5_datasets(f))
                best_path = cls._largest_2d(datasets.items())
                if best_path is None:
                    raise ValueError(
                        f"Could not find any suitable 2D dataset in {data_path}"
                    ) from e
                all_data = datasets[best_path][:]
                print(f"Found and loaded dataset '{best_path}' with shape {all_data.shape}")
                return all_data

    @classmethod
    def _read_preferred_h5_array(cls, f, data_key, nested_key):
        """Read ``data_key`` (optionally ``data_key/nested_key``), else common fallback keys."""
        if data_key in f:
            data_obj = f[data_key]
            # Membership on an h5py Dataset would iterate its rows, so only
            # probe for the nested key when data_obj is actually a group.
            if nested_key and isinstance(data_obj, h5py.Group) and nested_key in data_obj:
                all_data = data_obj[nested_key][:]
                print(f"Loaded data from {data_key}/{nested_key} with shape {all_data.shape}")
                return all_data

            if hasattr(data_obj, 'shape'):  # a dataset, not a group
                all_data = data_obj[:]
                print(f"Loaded data from {data_key} with shape {all_data.shape}")
                return all_data

            print(f"'{data_key}' is a group with keys: {list(data_obj.keys())}")
            best_key = cls._largest_2d((key, data_obj[key]) for key in data_obj.keys())
            if best_key is None:
                raise ValueError(f"No suitable dataset found in group '{data_key}'")
            all_data = data_obj[best_key][:]
            print(f"Loaded data from {data_key}/{best_key} with shape {all_data.shape}")
            return all_data

        for fallback_key in ('data', 'dataset', 'values', 'X', 'features'):
            if fallback_key in f:
                all_data = f[fallback_key][:]
                print(f"Loaded data from fallback key '{fallback_key}' with shape {all_data.shape}")
                return all_data

        first_key = list(f.keys())[0]
        print(f"Using first available key: {first_key}")
        data_obj = f[first_key]
        if hasattr(data_obj, 'shape'):
            all_data = data_obj[:]
        else:
            all_data = data_obj[list(data_obj.keys())[0]][:]
        print(f"Loaded data with shape {all_data.shape}")
        return all_data

    @classmethod
    def _walk_h5_datasets(cls, group, path=""):
        """Yield ``(path, dataset)`` for every dataset below ``group``, depth-first."""
        for key in group.keys():
            item_path = f"{path}/{key}" if path else key
            item = group[key]
            if hasattr(item, 'shape'):
                yield item_path, item
            elif hasattr(item, 'keys'):
                yield from cls._walk_h5_datasets(item, item_path)

    @staticmethod
    def _largest_2d(named_items):
        """Return the name of the 2-D item with the most elements (first wins ties), or None."""
        best_name, best_size = None, 0
        for name, item in named_items:
            shape = getattr(item, 'shape', None)
            if shape is not None and len(shape) == 2 and shape[0] * shape[1] > best_size:
                best_name, best_size = name, shape[0] * shape[1]
        return best_name

    def _load_text_data(self, data_path):
        """Build next-character windows from a GBK text file."""
        with open(data_path, 'r', encoding='gbk') as f:
            # Full-width spaces and newlines are both collapsed to ASCII spaces.
            text = f.read().replace("　", " ").replace("\n", " ")

        self.chars = sorted(set(text))
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = dict(enumerate(self.chars))
        self.vocab_size = len(self.chars)

        self.char_indices = [self.char_to_idx[ch] for ch in text]

        # Window i reads chars [i, i+seq_len); its target is shifted by one char.
        n_windows = range(len(self.char_indices) - self.seq_len)
        self.sequences = np.array([self.char_indices[i:i + self.seq_len] for i in n_windows])
        self.targets = np.array([self.char_indices[i + 1:i + self.seq_len + 1] for i in n_windows])

        self.num_features = self.vocab_size

    # ------------------------------------------------------------ preprocessing

    @staticmethod
    def _zscore(series):
        """Return ``(normalized, mean, std)``; a constant series is only mean-centred."""
        mean = np.mean(series)
        std = np.std(series)
        centred = series - mean
        return (centred / std if std > 0 else centred), mean, std

    def _standardize_and_create_sequences(self):
        """Z-score inputs (per column) and target, then cut stride-1 windows.

        Also records the statistics used (``input_mean``, ``input_std``,
        ``target_mean``, ``target_std``) so predictions can be de-normalized.

        NOTE(review): statistics are computed over every row, including rows
        that a chronological train/validation split later uses for
        validation — consider fitting them on the training range only.
        NOTE(review): every window is materialised up front, so memory is
        roughly ``seq_len`` times the size of the series.
        """
        num_timesteps, num_features = self.input_data.shape

        self.normalized_input_data = np.zeros_like(self.input_data)
        self.input_mean = np.zeros(num_features, dtype=np.float32)
        self.input_std = np.zeros(num_features, dtype=np.float32)
        for col in range(num_features):
            normalized, mean, std = self._zscore(self.input_data[:, col])
            self.normalized_input_data[:, col] = normalized
            self.input_mean[col] = mean
            self.input_std[col] = std

        self.normalized_target_data, self.target_mean, self.target_std = self._zscore(self.target_data)

        starts = range(num_timesteps - self.seq_len + 1)
        self.sequences = np.array(
            [self.normalized_input_data[t:t + self.seq_len, :] for t in starts]
        )
        self.targets = np.array(
            [self.normalized_target_data[t:t + self.seq_len].reshape(-1, 1) for t in starts]
        )
        self.num_features = num_features

    # ------------------------------------------------------------ Dataset API

    def __len__(self):
        return len(self.sequences)

    def _one_hot(self, indices):
        """Return a float32 one-hot tensor of shape ``indices.shape + (vocab_size,)``."""
        indices = np.asarray(indices)
        # Scatter 1.0 at each index instead of slicing a V x V identity matrix,
        # which would allocate vocab_size**2 floats on every item fetch.
        encoded = np.zeros(indices.shape + (self.vocab_size,), dtype=np.float32)
        np.put_along_axis(encoded, indices[..., None], 1.0, axis=-1)
        return torch.from_numpy(encoded)

    def __getitem__(self, idx):
        if self.dataset_name == 'harry_potter':
            return self._one_hot(self.sequences[idx]), self._one_hot(self.targets[idx])
        # FloatTensor copies, so training code cannot mutate the cached windows.
        return torch.FloatTensor(self.sequences[idx]), torch.FloatTensor(self.targets[idx])

def get_dataset(args):
    """Build chronological train/validation ``DataLoader``s for ``args.dataset``.

    Attributes read from ``args`` (only ``dataset`` is required):
    ``dataset`` (case-insensitive name), ``data_path`` (missing or None means
    the built-in default path), ``seq_len`` (24), ``max_days_per_user``
    (None), ``train_ratio`` (0.8), ``batchsize`` (32), ``num_workers`` (0).

    The first ``train_ratio`` fraction of windows forms the training set and
    the remainder the validation set. Training batches are shuffled and
    validation batches are not.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: for an unsupported dataset name or a ``train_ratio``
            outside ``[0, 1]``.
    """
    default_paths = {
        'electricity': './data/electricity.csv',
        'metr-la': './data/metr-la.h5',
        'pems-bay': './data/pems-bay.h5',
        'harry_potter': './data/harrypotter.txt',
    }

    dataset_name = args.dataset.lower()
    if dataset_name not in default_paths:
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    # argparse namespaces often carry data_path=None when the flag is omitted.
    data_path = getattr(args, 'data_path', None) or default_paths[dataset_name]

    # Validate before the (potentially slow) load so bad configs fail fast.
    train_ratio = getattr(args, 'train_ratio', 0.8)
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
    val_ratio = 1.0 - train_ratio

    dataset = TimeSeriesDataset(
        data_path=data_path,
        dataset_name=dataset_name,
        seq_len=getattr(args, 'seq_len', 24),
        max_days_per_user=getattr(args, 'max_days_per_user', None),
    )

    total_samples = len(dataset)
    train_size = int(total_samples * train_ratio)

    # NOTE(review): consecutive windows overlap by seq_len - 1 steps, so the
    # last training windows share time steps with the first validation ones.
    train_dataset = torch.utils.data.Subset(dataset, list(range(train_size)))
    val_dataset = torch.utils.data.Subset(dataset, list(range(train_size, total_samples)))

    batch_size = getattr(args, 'batchsize', 32)
    num_workers = getattr(args, 'num_workers', 0)
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    print(f"Dataset: {dataset_name}")
    print(f"Train: {len(train_dataset)} ({train_ratio:.1%}) | Val: {len(val_dataset)} ({val_ratio:.1%})")

    if total_samples:
        sample_input, sample_target = dataset[0]
        print(f"Input shape: {sample_input.shape} | Target shape: {sample_target.shape}")
    else:
        print("Warning: dataset produced no windows (series shorter than seq_len?)")

    return train_loader, val_loader

if __name__ == "__main__":
    import argparse

    # Smoke test: build loaders for every supported dataset and show one batch.
    for name in ('electricity', 'metr-la', 'pems-bay', 'harry_potter'):
        rule = '=' * 50
        print(f"\n{rule}")
        print(f"Testing {name} dataset")
        print(f"{rule}")

        is_text = name == 'harry_potter'
        try:
            args = argparse.Namespace(
                dataset=name,
                seq_len=50 if is_text else 24,
                batchsize=16,
                train_ratio=0.8,
                num_workers=0,
                max_days_per_user=None if is_text else 5,
            )

            train_loader, val_loader = get_dataset(args)

            # Only the first batch is inspected.
            for inputs, targets in train_loader:
                print(f"Batch shapes - Input: {inputs.shape} | Target: {targets.shape}")
                break

            print(f"✅ {name} dataset loaded successfully!")

        except Exception as err:
            print(f"❌ Error testing {name}: {err}")